import os
from dataset import create_dataset
from model.UNet import UNet
from utils.engine import DiffusionSDETrainer  # Replaced with Diffusion model trainer
from utils.tools import train_one_epoch, load_yaml
import torch
from utils.callbacks import ModelCheckpoint
from eigen_sde import DiffusionSDEPolicy  # Importing the diffusion model policy

def _training_state(model, optimizer, model_checkpoint, diffusion_policy, config, epoch):
    """Return a dict holding every piece of state needed to resume after ``epoch``.

    The keys are exactly the ones ``train`` reads back when ``config["consume"]``
    is truthy. The same helper feeds both ``ModelCheckpoint.step`` and the
    periodic ``torch.save`` snapshots, so the two formats cannot drift apart.
    """
    return {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "model_checkpoint": model_checkpoint.state_dict(),
        "config": config,
        "start_epoch": epoch,  # last completed epoch; resuming starts at epoch + 1
        "diffusion_policy": diffusion_policy.state_dict(),
    }


def train(config):
    """Train the UNet / diffusion-SDE setup described by ``config``, optionally resuming.

    Args:
        config: Parsed run configuration.
            Required keys: ``"consume"``, ``"consume_path"`` (only when
            resuming), ``"device"``, ``"Dataset"``, ``"Model"``, ``"lr"``,
            ``"Trainer"``, ``"Callback"`` and ``"epochs"``.
            Optional keys: ``"weight_decay"`` (default ``1e-4``),
            ``"checkpoint_dir"`` (default: the previously hard-coded
            directory) and ``"save_every"`` (default ``50``; a value
            ``<= 0`` disables the periodic snapshots).

    When ``config["consume"]`` is truthy, the checkpoint at
    ``config["consume_path"]`` is loaded. Its stored config replaces ``config``
    entirely. Model, optimizer, callback and diffusion-policy state are then
    restored, and training continues from the epoch after the saved one.
    """
    consume = config["consume"]
    if consume:
        # map_location="cpu" lets a checkpoint written on a GPU be resumed on
        # any machine. The load_state_dict calls below copy module/optimizer
        # tensors onto the devices of the freshly built objects.
        # NOTE(review): if ModelCheckpoint keeps tensors in its own state they
        # will now come back on CPU — confirm its comparisons tolerate that.
        # NOTE(review): torch.load unpickles arbitrary objects; only resume
        # from trusted checkpoint files.
        cp = torch.load(config["consume_path"], map_location="cpu")
        config = cp["config"]
    print(config)

    device = torch.device(config["device"])
    loader = create_dataset(**config["Dataset"])
    start_epoch = 1

    # Initialize the UNet model
    model = UNet(**config["Model"]).to(device)
    # NOTE(review): only the UNet's parameters are optimized. Any trainable
    # parameters of diffusion_policy are never updated by this optimizer, even
    # though its state is checkpointed. Confirm this is intended. Adding them
    # would change the optimizer's param groups and break resuming from
    # existing checkpoints.
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"],
                                  weight_decay=config.get("weight_decay", 1e-4))

    model_cfg = config["Model"]
    diffusion_policy = DiffusionSDEPolicy(state_dim=model_cfg["state_dim"],
                                          action_dim=model_cfg["action_dim"],
                                          hidden_dim=model_cfg["hidden_dim"],
                                          T=model_cfg["T"],
                                          max_action=model_cfg["max_action"]).to(device)

    trainer = DiffusionSDETrainer(model, diffusion_policy, **config["Trainer"]).to(device)

    model_checkpoint = ModelCheckpoint(**config["Callback"])

    if consume:
        model.load_state_dict(cp["model"])
        optimizer.load_state_dict(cp["optimizer"])
        model_checkpoint.load_state_dict(cp["model_checkpoint"])
        diffusion_policy.load_state_dict(cp["diffusion_policy"])
        start_epoch = cp["start_epoch"] + 1

    # Directory and interval for the periodic snapshots. Both are configurable,
    # with the original hard-coded values kept as defaults.
    checkpoint_dir = config.get("checkpoint_dir", '/idas/users/liusirui/0926/k-3/checkpoint/step750/')
    save_every = config.get("save_every", 50)
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(start_epoch, config["epochs"] + 1):
        loss = train_one_epoch(trainer, loader, optimizer, device, epoch)

        # Hand the loss and full resume state to the callback. Its save policy
        # lives in utils.callbacks.ModelCheckpoint.
        model_checkpoint.step(loss, **_training_state(model, optimizer, model_checkpoint,
                                                      diffusion_policy, config, epoch))

        if save_every > 0 and epoch % save_every == 0:
            save_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
            # Built after step() so the snapshot includes the callback's
            # updated state.
            state = _training_state(model, optimizer, model_checkpoint,
                                    diffusion_policy, config, epoch)
            # Write to a temp file, then atomically rename. An interrupted save
            # (e.g. job preemption) then never leaves a truncated checkpoint
            # under the final name.
            tmp_path = save_path + ".tmp"
            torch.save(state, tmp_path)
            os.replace(tmp_path, save_path)
            print(f"Checkpoint saved at epoch {epoch} to {save_path}")

if __name__ == "__main__":
    # Script entry point: read the YAML run configuration from the current
    # working directory and pass it straight to train().
    train(load_yaml("config.yml", encoding="utf-8"))
